#!/usr/bin/env python

'''This script creates multiple commands according to the variations specified
   on the command line. These commands will be packaged into a shell script
   and a SLURM job. For example:

     $ jexplode foo 4 cmd { -a } -b { c , d e } @dir/@id

   will produce four variations:

     cmd -a -b c foo/0
     cmd -b c foo/1
     cmd -a -b d e foo/2
     cmd -b d e foo/3

   That is:

     * The first argument is a directory in which to place the job.

     * The second argument is the number of CPUs to allocate per task.

     * Single argument groups surrounded by { and } have two variations:
       present and absent.

     * Multiple argument groups, separated by commas and surrounded by { and }
       will have one variation for each group.

     * The strings @dir and @id are replaced by the job directory and the
       job ID, respectively. This replacement occurs *after* the commands are
       expanded.'''

import errno
import itertools
import os
import sys

import testable


def main():
   '''Parse the command line, expand the command variations, and write the
      job directory tree, the SLURM job script, the SLURM task list, and the
      serial shell script.

      Expected arguments after the program name:

        1. job directory (converted to an absolute path before use)
        2. number of CPUs per task (an integer)
        3. one or more command tokens, possibly containing { , } groups

      Usage problems (too few arguments, a non-integer CPU count, no command
      tokens, malformed brace groups) are reported through fatal(), which
      prints a message on stderr and exits with status 1.'''
   cmdline = ' '.join(sys.argv)  # recorded verbatim in the generated scripts
   args = sys.argv[1:]  # work on a copy; sys.argv itself is left untouched
   if (len(args) < 2):
      fatal('usage: jexplode DIR CPUS_PER_TASK CMD [ARG ...]')
   dir_ = os.path.abspath(args[0])
   try:
      cpu_ct = int(args[1])
   except ValueError:
      fatal('CPUs per task must be an integer, not: %s' % (args[1],))
   tokens = args[2:]
   if (len(tokens) == 0):
      fatal('no commands, aborting')
   try:
      cs = commands_build(dir_, tokens)
   except ValueError as x:
      # explode() signals malformed brace groups with ValueError.
      fatal('bad command specification: %s' % (str(x),))
   make_dirs(dir_, len(cs))
   make_slurm(dir_, cpu_ct, cs, cmdline)
   make_tasklist(dir_, cs)
   make_script(dir_, cs, cmdline)
   # Single parenthesized argument: prints the same text on Python 2 and 3.
   print('created job with %d commands in %s' % (len(cs), dir_))

def commands_build(dir_, args):
   '''Expand the token list args with explode() and return one command
      string per variation, in the order explode() produces them. In each
      command the placeholder @dir is replaced by dir_ and @id by the
      variation's zero-based index. For example:

        >>> commands_build('foo', 'a { b , c } @dir/@id'.split())
        ['a b foo/0', 'a c foo/1']'''
   def sub(cmd, id_):
      # Expand @id before @dir: the substituted id is all digits and cannot
      # create a new placeholder, whereas expanding @dir first would let an
      # '@id' that happens to occur inside dir_ get rewritten as well.
      return cmd.replace('@id', str(id_)).replace('@dir', dir_)
   return [sub(' '.join(c), i) for (i, c) in enumerate(explode(args))]

def explode(tokens):
   '''Given a list of tokens with special open ('{'), close ('}'), and
      delimiter (',') repetition tokens, return a list of lists of tokens,
      one for each repetition. For example:

        >>> [' '.join(i) for i in explode('cmd { -a } -b { c , d e }'.split())]
        ['cmd -a -b c', 'cmd -b c', 'cmd -a -b d e', 'cmd -b d e']

      A group with a single non-empty alternative expands to "present" and
      "absent"; a group with several alternatives expands to one variation
      per alternative. When there are several groups, the earliest group
      varies fastest in the result. The input list is not modified.

      Raises ValueError if a close brace has no matching open brace, or if
      an open brace is never closed.'''
   # FIXME: This is hellaciously more complicated than I wanted. Maybe there's
   # a parsing library I can call instead?
   unsplit = []     # tokens before the first group; shared by all variations
   choices = []     # alternatives of the group currently being read
   opened = False   # True while a '{' is waiting for its '}'
   for (pos, t) in enumerate(tokens):
      if (t == '{'):
         opened = True
         choices.append([])
      elif (t == ','):
         choices.append([])
      elif (t == '}'):
         if (len(choices) == 0):
            raise ValueError('close brace without open')
         opened = False
         if (len(choices) == 1 and len(choices[0]) > 0):
            # Lone alternative: add an "absent" variation. None is a
            # placeholder that is stripped from the final result below.
            choices.append([None])
         # Everything after the group is expanded recursively and combined
         # with every alternative. explode() always returns at least one
         # list, so this never discards the alternatives.
         tails = explode(tokens[pos + 1:])
         choices = [choice + tail for tail in tails for choice in choices]
         break
      elif (len(choices) > 0):
         choices[-1].append(t)
      else:
         unsplit.append(t)
   if (opened):
      # Without this check "{ a" would quietly yield the single variation
      # ['a'] instead of the two the user asked for.
      raise ValueError('open brace without close')
   if (len(choices) == 0):
      result = [unsplit]
   else:
      result = [unsplit + choice for choice in choices]
   return [[tok for tok in variant if tok is not None] for variant in result]

def fatal(msg):
   '''Write msg followed by a newline to stderr, then exit with status 1
      via sys.exit(). Does not return.'''
   # write() instead of the "print >>" statement: same output, and the line
   # is valid on both Python 2 and Python 3.
   sys.stderr.write('%s\n' % (msg,))
   sys.exit(1)

def make_dirs(path, cmd_ct):
   '''Create the job directory path (with any missing parents) and, inside
      it, one subdirectory per command named 0 through cmd_ct - 1.

      Any OSError raised while creating directories (for example because
      path already exists) is reported through fatal().'''
   try:
      os.makedirs(path)
      for i in range(cmd_ct):
         os.mkdir(os.path.join(path, str(i)))
   except OSError as x:
      fatal("can't make directories in %s: %s" % (path, str(x)))

def make_script(dir_, cs, cmdline):
   '''Write dir_/serial.sh, a /bin/sh script that changes into dir_ and runs
      the commands in cs one after another, sending both stdout and stderr
      of command number i to i/out (relative to dir_).

      The script starts with a comment recording cmdline and uses "set -e".
      The file's mode is set to 0770.'''
   header = '''\
#!/bin/sh

# jexplode command line:
# %s

set -e
cd %s
''' % (cmdline, dir_)
   with open('%s/serial.sh' % (dir_), 'wt') as fp:
      fp.write(header + '\n')
      for (i, cmd) in enumerate(cs):
         # "> file 2>&1" is the POSIX spelling. The shorter ">& file" is a
         # bash/csh extension whose behavior under /bin/sh is unspecified.
         fp.write('%s > %d/out 2>&1\n' % (cmd, i))
      os.fchmod(fp.fileno(), 0o770)

def make_slurm(dir_, cpus_per_task, cs, cmdline):
   '''Write dir_/slurm_job, a /bin/sh script carrying #SBATCH directives
      followed by an srun --multi-prog invocation of the slurm_tasks file.

      dir_           job directory; used for the stdout/stderr paths and the
                     cd, and its basename becomes the job name
      cpus_per_task  integer passed to --cpus-per-task
      cs             list of commands; only its length is used (--ntasks)
      cmdline        recorded in a header comment

      The notification address is $USER@lanl.gov, so KeyError is raised if
      USER is not set in the environment. The file's mode is set to 0770.'''
   # Gather every substitution first, so a missing USER fails before any
   # file is created.
   fields = {
      'cmdline': cmdline,
      'cpus_per_task': cpus_per_task,
      'dir_': dir_,
      'dir_tail': os.path.basename(dir_),
      'email': os.environ['USER'] + '@lanl.gov',
      'cmd_ct': len(cs),
   }
   text = '''\
#!/bin/sh

# jexplode command line:
# %(cmdline)s

#SBATCH --cpus-per-task=%(cpus_per_task)d
#SBATCH --error=%(dir_)s/stderr
#SBATCH --exclusive
#SBATCH --job-name=%(dir_tail)s
#SBATCH --mail-type=ALL
#SBATCH --mail-user=%(email)s
#SBATCH --ntasks=%(cmd_ct)d
#SBATCH --output=%(dir_)s/stdout

cd %(dir_)s
srun --cpus-per-task=%(cpus_per_task)d --ntasks=%(cmd_ct)d --distribution=cyclic --output=%%t/out --multi-prog slurm_tasks
''' % fields
   with open('%s/slurm_job' % (dir_), 'wt') as fp:
      fp.write(text + '\n')
      os.fchmod(fp.fileno(), 0o770)

def make_tasklist(dir_, cs):
   '''Write dir_/slurm_tasks with one line "<i> sh <dir_>/<i>/job.sh" per
      command, and write command number i, followed by a newline, to
      dir_/<i>/job.sh. The per-command directories must already exist.'''
   with open('%s/slurm_tasks' % (dir_), 'wt') as tasks:
      for (i, cmd) in enumerate(cs):
         jobfn = '%s/%d/job.sh' % (dir_, i)
         tasks.write('%d sh %s\n' % (i, jobfn))
         # Close each job file explicitly instead of relying on garbage
         # collection to flush it.
         with open(jobfn, 'wt') as job:
            job.write('%s\n' % (cmd,))


# Script entry point: when this file is executed directly,
# testable.do_script_tests() is called first and main() runs only if it
# returns a false value.
# NOTE(review): presumably do_script_tests() runs this module's tests when the
# script is invoked in some test mode -- confirm against the testable module.
if (__name__ == '__main__' and not testable.do_script_tests()):
   main()

# Extra doctest-style cases for explode(), passed to testable.register() as a
# single string.
# NOTE(review): this call is placed after the entry-point check, so when the
# file is run as a script it executes only after do_script_tests() and/or
# main() have returned -- confirm that testable still picks these cases up.
testable.register('''

   # simple cases
   >>> explode('a'.split())
   [['a']]
   >>> explode('{ a }'.split())
   [['a'], []]
   >>> explode('{ a b }'.split())
   [['a', 'b'], []]
   >>> explode('{ a , b }'.split())
   [['a'], ['b']]
   >>> explode('a { b , c }'.split())
   [['a', 'b'], ['a', 'c']]
   >>> explode('a { b , c } d'.split())
   [['a', 'b', 'd'], ['a', 'c', 'd']]
   >>> explode('{ }'.split())
   [[]]
   >>> explode('{ , }'.split())
   [[], []]

   # explode() should give ValueError for unmatched close brace
   >>> explode('}'.split())
   Traceback (most recent call last):
     ...
   ValueError: close brace without open

''')
